{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 定制 Pass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tvm\n",
    "from tvm import te"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Declare an elementwise add of two length-128 vectors with the TE API.\n",
    "n = tvm.tir.const(128, \"int32\")\n",
    "a = te.placeholder((n,), name=\"a\")\n",
    "b = te.placeholder((n,), name=\"b\")\n",
    "c = te.compute((n,), lambda i: a[i] + b[i], name=\"c\")\n",
    "\n",
    "# Create a schedule for `c`, lower it, and print the lowered result.\n",
    "sch = te.create_schedule(c.op)\n",
    "ir = tvm.lower(sch, [a, b, c])\n",
    "ir.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "loops = []\n",
    "\n",
    "def find_width8(op):\n",
    "    \"\"\"Record every `tir.For` node whose constant extent is a multiple of 8.\n",
    "\n",
    "    Matching nodes are appended to the module-level `loops` list; any other\n",
    "    node is ignored.\n",
    "    \"\"\"\n",
    "    if not isinstance(op, tvm.tir.For):\n",
    "        return\n",
    "    trip_count = op.extent\n",
    "    # Only an integer-immediate extent can be checked for divisibility here.\n",
    "    if isinstance(trip_count, tvm.tir.IntImm) and trip_count.value % 8 == 0:\n",
    "        loops.append(op)\n",
    "\n",
    "def vectorize8(op):\n",
    "    \"\"\"Split a loop recorded by `find_width8` and vectorize its inner part.\n",
    "\n",
    "    A loop found in `loops` is rewritten as a serial outer loop of\n",
    "    `extent // 8` iterations around a vectorized inner loop of 8 iterations,\n",
    "    with the original loop variable replaced by `outer * 8 + inner`.\n",
    "    Returns the new statement, or None for a loop that was not recorded.\n",
    "    \"\"\"\n",
    "    if op not in loops:\n",
    "        return None\n",
    "    trip_count = op.extent.value\n",
    "    var_name = op.loop_var.name\n",
    "    outer = te.var(var_name + \".outer\")\n",
    "    inner = te.var(var_name + \".inner\")\n",
    "    new_body = tvm.tir.stmt_functor.substitute(op.body, {op.loop_var: outer * 8 + inner})\n",
    "    vector_loop = tvm.tir.For(inner, 0, 8, tvm.tir.ForKind.VECTORIZED, new_body)\n",
    "    return tvm.tir.For(outer, 0, trip_count // 8, tvm.tir.ForKind.SERIAL, vector_loop)\n",
    "\n",
    "\n",
    "@tvm.tir.transform.prim_func_pass(opt_level=0)\n",
    "def vectorize(f, mod, ctx):\n",
    "    \"\"\"PrimFunc pass that vectorizes loops whose extent is a multiple of 8.\n",
    "\n",
    "    Visits `f.body` with `find_width8` to collect candidate loops, then\n",
    "    rewrites them with `vectorize8`. Returns `f` unchanged when no\n",
    "    candidate loop is found.\n",
    "    \"\"\"\n",
    "    # Start from an empty list on every invocation: `loops` is module-level,\n",
    "    # so without this, nodes collected for an earlier function (or an earlier\n",
    "    # run of this pass) would pile up and make the emptiness check below\n",
    "    # pass even when `f` has no matching loop.\n",
    "    loops.clear()\n",
    "    tvm.tir.stmt_functor.post_order_visit(f.body, find_width8)\n",
    "    if not loops:\n",
    "        return f\n",
    "    # The last list argument names the node types to transform, so only\n",
    "    # `For` nodes are passed to `vectorize8`.\n",
    "    return f.with_body(tvm.tir.stmt_functor.ir_transform(f.body, None, vectorize8, [\"tir.For\"]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the `info` attribute of the `vectorize` pass object.\n",
    "vectorize.info"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Register `vectorize` through the `tir.add_lower_pass` config entry so that\n",
    "# `tvm.lower` runs it; the 1 in the tuple is presumably the lowering phase\n",
    "# after which the pass is applied -- confirm against the TVM docs.\n",
    "with tvm.transform.PassContext(config={\"tir.add_lower_pass\": [(1, vectorize)]}) as ctx:\n",
    "    print(ctx)\n",
    "    tvm.lower(sch, [a, b, c]).show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the `For` nodes that `find_width8` has appended to `loops`.\n",
    "loops"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@tvm.tir.transform.prim_func_pass(opt_level=1)\n",
    "class TestReplaceFunc:\n",
    "    \"\"\"Demo pass class whose transform returns a fixed replacement function.\"\"\"\n",
    "\n",
    "    def __init__(self, new_func):\n",
    "        # Function returned by `transform_function` in place of its input.\n",
    "        self.new_func = new_func\n",
    "\n",
    "    def transform_function(self, func, mod, ctx):\n",
    "        \"\"\"Ignore `func`, `mod` and `ctx` and return the stored `new_func`.\"\"\"\n",
    "        # just for demo purposes\n",
    "        # transform func to new_func\n",
    "        return self.new_func\n",
    "    \n",
    "@tvm.tir.transform.prim_func_pass(opt_level=2)\n",
    "def transform(func, mod, ctx):\n",
    "    \"\"\"Placeholder pass body: return `func` unchanged.\"\"\"\n",
    "    # my transformations here.\n",
    "    return func\n",
    "\n",
    "function_pass = transform\n",
    "# `transform` now names the pass object built by the decorator above, so the\n",
    "# pass class has to be looked up through the full `tvm.tir.transform` path.\n",
    "assert isinstance(function_pass, tvm.tir.transform.PrimFuncPass)\n",
    "assert function_pass.info.opt_level == 2\n",
    "\n",
    "# Apply the pass to the module lowered earlier in this notebook (`ir`).\n",
    "updated_mod = function_pass(ir)\n",
    "# `transform` returns each function as-is, so `updated_mod` holds the same\n",
    "# functions as `ir`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import tvm\n",
    "from tvm import relay\n",
    "\n",
    "def example():\n",
    "    \"\"\"Build a small Relay function mixing a conv2d with constant arithmetic.\n",
    "\n",
    "    Returns a `relay.Function` over the variables `x` and `weight` whose body\n",
    "    is the sum of two separately built `conv + (c + c) * 2 + c` expressions,\n",
    "    where `c` is a float32 constant of shape (1, 64, 54, 54).\n",
    "    \"\"\"\n",
    "    const_shape = (1, 64, 54, 54)\n",
    "    const_values = np.empty(const_shape).astype(\"float32\")\n",
    "    const_tensor = relay.const(const_values)\n",
    "    weight = relay.var(\"weight\", shape=(64, 64, 3, 3))\n",
    "    x = relay.var(\"x\", relay.TensorType((1, 64, 56, 56), \"float32\"))\n",
    "    conv_out = relay.nn.conv2d(x, weight, kernel_size=(3, 3))\n",
    "    # (c + c) * 2 involves constants only.\n",
    "    doubled = relay.add(const_tensor, const_tensor)\n",
    "    const_expr = relay.multiply(doubled, relay.const(2, \"float32\"))\n",
    "    biased = relay.add(conv_out, const_expr)\n",
    "    # Two separately built but structurally identical `biased + c` nodes.\n",
    "    lhs = relay.add(biased, const_tensor)\n",
    "    rhs = relay.add(biased, const_tensor)\n",
    "    return relay.Function([x, weight], relay.add(lhs, rhs))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.13 64-bit",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "949777d72b0d2535278d3dc13498b2535136f6dfe0678499012e853ee9abcab1"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
